# analysis.py
import torch
import torch.nn.functional as F
import numpy as np
import os
import matplotlib.pyplot as plt
from utils import gaussian_gram_matrix, compute_von_neumann_entropy_from_density, compute_joint_entropy
from transformers import AutoTokenizer

def _gather_positions(hidden, positions):
    """Return ``hidden[i, positions[i], :]`` for every sample ``i``, shaped [N, 1, D]."""
    picked = torch.stack([hidden[i, pos, :] for i, pos in enumerate(positions)])
    # Add a length-1 sequence axis so single-token representations have the same
    # rank as the padded multi-token X_prev representation.
    return picked.unsqueeze(1)


def _padded_previous_tokens(embeddings, context_lengths):
    """Stack the tokens preceding each sample's final context token into [N, L, D].

    Sample ``i`` contributes ``embeddings[i, :cl - 1, :]``; a sample with no
    preceding token (``cl <= 1``) contributes a single zero row. Every sample is
    right-padded with zero rows to a common length ``L >= 1``.
    """
    # At least 1 so a sample with no previous token still fits its placeholder row.
    max_len = max(1, max(cl - 1 for cl in context_lengths))
    rows = []
    for i, cl in enumerate(context_lengths):
        if cl - 1 > 0:
            tokens = embeddings[i, :cl - 1, :]
        else:
            tokens = embeddings.new_zeros((1, embeddings.size(-1)))
        # Pad based on the actual row count (not cl - 1) so the placeholder row
        # does not make this sample one row longer than the others.
        pad_len = max_len - tokens.size(0)
        if pad_len > 0:
            # new_zeros keeps the dtype/device of the hidden states (e.g. fp16 on GPU).
            padding = tokens.new_zeros((pad_len, tokens.size(-1)))
            tokens = torch.cat([tokens, padding], dim=0)
        rows.append(tokens)
    return torch.stack(rows, dim=0)


def _gram_and_entropy(reps, sigma):
    """Return ``(gram, entropy)`` for a batch of representations.

    ``reps`` is L2-normalised along its last dimension before the Gaussian Gram
    matrix is built. The entropy is computed on the Gram matrix divided by its
    trace. The un-normalised Gram matrix is returned because that is what the
    joint-entropy calls consume.
    """
    gram = gaussian_gram_matrix(F.normalize(reps, p=2, dim=-1), sigma)
    entropy = compute_von_neumann_entropy_from_density(gram / torch.trace(gram))
    return gram, entropy


def analyze_batch(hidden_states_z, hidden_states_y, input_ids, context_lengths, lm_head_weight, sigma=1.0, save_dir=None):
    """Estimate matrix-based entropies and mutual information for one batch.

    Args:
        hidden_states_z: Sequence of hidden-state tensors indexed as
            ``[sample, position, feature]``. Entry 0 is used as the embedding
            layer; entries 1..len-1 are analysed as per-layer outputs.
        hidden_states_y: Sequence whose entry 0 supplies the target-token
            representation ``Y`` at position ``context_lengths[i]``.
        input_ids: Currently unused; kept for API compatibility.
        context_lengths: Per-sample context lengths. Position ``cl - 1`` is the
            final context token and ``cl`` the target token (assumes ``cl >= 1``
            — TODO confirm with callers).
        lm_head_weight: Currently unused; kept for API compatibility.
        sigma: Kernel width passed to ``gaussian_gram_matrix``.
        save_dir: Currently unused; kept for API compatibility.

    Returns:
        dict with keys ``"H_X"``, ``"H_Y"``, ``"H_X_prev"``, ``"H_joint_XY"``,
        ``"H_joint_XprevY"``, ``"I_XY"``, ``"I_XprevY"`` and ``"layer_results"``.
        ``"layer_results"`` maps ``"layer_<l>"`` to a dict with ``"H_Z_final"``,
        ``"H_joint_XZ"``, ``"I_XZ"``, ``"H_joint_ZY"`` and ``"I_ZY"``.

    Raises:
        ValueError: if ``len(context_lengths)`` differs from the batch size.
    """
    embeddings = hidden_states_z[0]
    batch_size = embeddings.size(0)
    if len(context_lengths) != batch_size:
        raise ValueError(
            f"context_lengths has {len(context_lengths)} entries but the batch has {batch_size} samples"
        )

    final_positions = [cl - 1 for cl in context_lengths]

    # X: final context token taken from the embedding layer.
    gram_x, H_X = _gram_and_entropy(_gather_positions(embeddings, final_positions), sigma)
    # X_prev: every embedding-layer token before the final context token.
    gram_x_prev, H_X_prev = _gram_and_entropy(_padded_previous_tokens(embeddings, context_lengths), sigma)
    # Y: representation of the target token (position cl).
    gram_y, H_Y = _gram_and_entropy(_gather_positions(hidden_states_y[0], context_lengths), sigma)

    H_joint_XY = compute_joint_entropy(gram_x, gram_y)
    H_joint_XprevY = compute_joint_entropy(gram_x_prev, gram_y)
    results = {
        "H_X": H_X,
        "H_Y": H_Y,
        "H_X_prev": H_X_prev,
        "H_joint_XY": H_joint_XY,
        "H_joint_XprevY": H_joint_XprevY,
        "I_XY": H_X + H_Y - H_joint_XY,
        "I_XprevY": H_X_prev + H_Y - H_joint_XprevY,
    }

    # Per-layer analysis of the final context token's representation Z.
    layer_results = {}
    for layer_idx in range(1, len(hidden_states_z)):
        gram_z, H_Z_final = _gram_and_entropy(
            _gather_positions(hidden_states_z[layer_idx], final_positions), sigma
        )
        H_joint_XZ = compute_joint_entropy(gram_x, gram_z)
        H_joint_ZY = compute_joint_entropy(gram_z, gram_y)
        layer_results[f"layer_{layer_idx}"] = {
            "H_Z_final": H_Z_final,
            "H_joint_XZ": H_joint_XZ,
            "I_XZ": H_X + H_Z_final - H_joint_XZ,
            "H_joint_ZY": H_joint_ZY,
            "I_ZY": H_Z_final + H_Y - H_joint_ZY,
        }

    results["layer_results"] = layer_results
    return results

def visualize_information_flow(results, save_dir="figures", epoch_checkpoint=None):
    """Plot per-layer I(Z;Y) with I(X;Y) / I(Xprev;Y) reference lines and save it.

    Args:
        results: Dict as returned by ``analyze_batch``. It must contain
            ``"layer_results"`` (keys ``"layer_<n>"``, each with ``"I_ZY"``) and
            may contain ``"I_XY"`` and ``"I_XprevY"``. Values may be Python
            numbers or single-element tensors.
        save_dir: Directory the figure is written to; created if missing.
        epoch_checkpoint: Optional tag included in the file name and title.

    Returns:
        None. Writes ``mutual_information[_epoch_<tag>].png`` into ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    layer_results = results["layer_results"]
    layers = sorted(layer_results, key=lambda name: int(name.split("_")[1]))
    layer_nums = [int(name.split("_")[1]) for name in layers]
    # float() turns scalar tensors (possibly on GPU) into plain numbers matplotlib accepts.
    I_ZY_values = [float(layer_results[name]["I_ZY"]) for name in layers]

    I_XY = results.get("I_XY")
    I_XprevY = results.get("I_XprevY")

    suffix = "" if epoch_checkpoint is None else f"_epoch_{epoch_checkpoint}"
    save_path = os.path.join(save_dir, f"mutual_information{suffix}.png")

    fig, ax = plt.subplots()
    try:
        ax.plot(layer_nums, I_ZY_values, marker='^', linestyle='-', label="I(Z;Y)")

        # Reference levels as dashed horizontal lines; each is drawn only if present.
        if I_XY is not None:
            ax.axhline(y=float(I_XY), linestyle='--', color='black', label="I(X;Y)")
        if I_XprevY is not None:
            ax.axhline(y=float(I_XprevY), linestyle='--', color='grey', label="I(Xprev;Y)")

        ax.set_xlabel("Layer")
        ax.set_ylabel("Mutual Information")
        title = "Mutual information per layer"
        if epoch_checkpoint is not None:
            title += f" (epoch {epoch_checkpoint})"
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

        fig.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Mutual information plot saved to {save_path}")